import numpy as np
import tensorflow as tf

# 构建实验数据
train_x = np.linspace(-1, 1, 100)
# y = 2 * x + b
train_y = 2. * train_x + np.random.randn(*train_x.shape) * 0.3

# 创建模型
# 占位符
X = tf.placeholder(tf.float32)
Y = tf.placeholder(tf.float32)
# 模型参数
weights = tf.Variable(tf.random_normal([1]), name='weights')
biases = tf.Variable(tf.zeros([1]), name='biases')
z = tf.multiply(X, weights) + biases

# 构建损失函数
loss = tf.reduce_mean(tf.square(Y - z))
# 定义学习率
learning_rate = 0.01
# 构建优化函数
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# 最小化损失函数
train = optimizer.minimize(loss)

# 初始化所有变量
init = tf.global_variables_initializer()

# 定义 epochs
training_epochs = 20
# 每隔两步显示一次中间值
display_step = 2

# 存放批次值和损失值
plot_data = {'batchsize': [], 'loss': []}


# 定义保存模型对象
saver = tf.train.Saver(max_to_keep=2)
save_dir = 'logs/'  # 生成模型的路径
# 启动Session
with tf.Session() as sess:
    # 初始化全局变量
    sess.run(init)

    # 向模型中 feed 数据
    for epoch in range(training_epochs):
        for (x, y) in zip(train_x, train_y):
            feed_dict = {X: x, Y: y}
            sess.run(train, feed_dict=feed_dict)

        # 显示训练中的数据
        if epoch % display_step == 0:
            loss_ = sess.run(loss, feed_dict={X: train_x, Y: train_y})
            print('epoch:', epoch + 1, 'loss = ', loss_, 'weights=',
                  sess.run(weights), 'biases=', sess.run(biases))
            # 保存检查点
            saver.save(sess, save_dir + 'linearModel.cpkt', global_step=epoch)
    print('Finished...')
    # 保存模型
    saver.save(sess, save_dir+'linerModel.cpkt')  # 如果指定的文件夹不存在会自动创建
    print('loss=', sess.run(loss, feed_dict={X: train_x, Y: train_y}), 'weights=',
          sess.run(weights), 'biases=', sess.run(biases))

# 载入检查点
load_epoch = 18
with tf.Session() as sess_2:
    saver.restore(sess_2, save_dir+'linearModel.cpkt-' + str(load_epoch))
    print('下面是检查点的结果： ')
    print('x=0.2, z=', sess_2.run(z, feed_dict={X: 0.2}))

# trainMonitored
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='logs/ckpt',
                                       save_checkpoint_secs=2) as sess:
    print(sess.run([global_step]))
    while not sess.should_stop():  # 启用死循环，session不停止就不结束
        i = sess.run(step)
        print(i)





